package com.ruoyi.websocket.handler;

import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONObject;
import com.ruoyi.websocket.utils.WebSocketUserSessionUtil;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.*;

import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * WebSocket handler for a simple chat room whose participants are keyed by user id.
 *
 * <p>Each session is expected to carry {@code userId} and {@code nickName} attributes (presumably
 * populated by a handshake interceptor — verify against the WebSocket configuration). Sessions are
 * registered with {@link WebSocketUserSessionUtil}. Inbound JSON messages are dispatched on their
 * {@code type} field:
 * <ul>
 *   <li>{@code broadcast} — relayed to every open session;</li>
 *   <li>{@code private} — delivered to {@code toUserId} and echoed back to the sender;</li>
 *   <li>{@code ping} — currently ignored.</li>
 * </ul>
 * Every connect/disconnect triggers a system broadcast plus a refreshed {@code userList} push.
 */
@Component
public class ChatRoomUserIdWebsocketHandler extends TextWebSocketHandler {

    /** Display name used for messages that originate from the server itself. */
    private static final String SYSTEM_NICK_NAME = "系统";

    /** Optional cache of nickname by user id; filled on connect, cleared on close. */
    private final Map<String, String> userNickMap = new ConcurrentHashMap<>();

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String userId = (String) session.getAttributes().get("userId");
        String nickName = (String) session.getAttributes().get("nickName");

        WebSocketUserSessionUtil.addSession(userId, session);
        // ConcurrentHashMap rejects null keys/values; a missing attribute must not abort the callback.
        if (userId != null && nickName != null) {
            userNickMap.put(userId, nickName);
        }

        System.out.println("用户连接：" + nickName + "（" + userId + "）");

        // Notify everyone that the user came online, then push the refreshed roster.
        broadcastJsonMessage("system", "broadcast", "【广播信息】" + nickName + " 已连接");
        broadcastUserList();
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        String userId = (String) session.getAttributes().get("userId");
        String nickName = (String) session.getAttributes().get("nickName");

        WebSocketUserSessionUtil.removeSession(session);
        // ConcurrentHashMap.remove(null) throws NPE.
        if (userId != null) {
            userNickMap.remove(userId);
        }
        System.out.println("【广播信息】" + nickName + " 已下线");
        broadcastJsonMessage("system", "broadcast", "【广播信息】" + nickName + " 已下线");
        broadcastUserList();
    }

    /**
     * Dispatches an inbound JSON chat message on its {@code type} field.
     *
     * <p>Unparseable payloads and unknown/missing types are answered with a system message to the
     * sender instead of letting an exception escape the handler.
     */
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        JSONObject json = parseOrNull(message.getPayload());
        if (json == null) {
            sendSystemMessage(session, "消息格式错误");
            return;
        }
        String type = json.getString("type");
        String content = json.getString("content");
        String toUserId = json.getString("toUserId");

        String fromUserId = (String) session.getAttributes().get("userId");
        String fromNickName = (String) session.getAttributes().get("nickName");

        System.out.println(fromNickName + "（" + fromUserId + "）发送消息：" + content);

        // A switch on a null String throws NPE; map null to "" so it reaches the default branch.
        switch (type == null ? "" : type) {
            case "broadcast":
                broadcast(buildChatMessage(fromUserId, fromNickName, "broadcast", content));
                break;

            case "private":
                sendPrivateMessage(session, fromUserId, fromNickName, toUserId, content);
                break;

            case "ping":
                // TODO: heartbeat mechanism
                break;

            default:
                sendSystemMessage(session, "未知消息类型：" + type);
        }
    }

    /**
     * Parses {@code payload} as a JSON object.
     *
     * @return the parsed object, or {@code null} when the payload is empty or not a JSON object
     */
    private static JSONObject parseOrNull(String payload) {
        try {
            return JSON.parseObject(payload);
        } catch (RuntimeException e) {
            // fastjson2 reports malformed input with unchecked exceptions.
            return null;
        }
    }

    /**
     * Delivers a private message to {@code toUserId} and echoes it to the sender.
     * If the recipient is unknown or closed, the sender gets a "not online" system message instead.
     */
    private void sendPrivateMessage(WebSocketSession sender, String fromUserId, String fromNickName,
                                    String toUserId, String content) {
        WebSocketSession toSession = toUserId == null ? null : WebSocketUserSessionUtil.getSession(toUserId);
        if (toSession == null || !toSession.isOpen()) {
            sendSystemMessage(sender, "用户 " + toUserId + " 不在线");
            return;
        }
        TextMessage chat = buildChatMessage(fromUserId, fromNickName, "private", content);
        sendSafely(toSession, chat);
        // Echo to the sender, unless they messaged themselves (already delivered above).
        if (!toSession.getId().equals(sender.getId())) {
            sendSafely(sender, chat);
        }
    }

    /** Pushes the current roster ({@code userId} + {@code nickName} per registered session) to everyone. */
    private void broadcastUserList() {
        List<JSONObject> users = WebSocketUserSessionUtil.getAllSessions().stream().map(session -> {
            JSONObject entry = new JSONObject();
            entry.put("userId", session.getAttributes().get("userId"));
            entry.put("nickName", session.getAttributes().get("nickName"));
            return entry;
        }).collect(Collectors.toList());

        JSONObject payload = new JSONObject();
        payload.put("type", "userList");
        payload.put("users", users);

        broadcast(new TextMessage(payload.toJSONString()));
    }

    /** Builds the JSON envelope ({@code type = "message"}) shared by all chat/system messages. */
    private static TextMessage buildChatMessage(String fromUserId, String fromNickName, String msgType, String content) {
        JSONObject json = new JSONObject();
        json.put("type", "message");
        json.put("fromUserId", fromUserId);
        json.put("fromNickName", fromNickName);
        json.put("msgType", msgType);
        json.put("content", content);
        return new TextMessage(json.toJSONString());
    }

    /** Sends an already-serialized message to every open registered session. */
    private static void broadcast(TextMessage message) {
        for (WebSocketSession s : WebSocketUserSessionUtil.getAllSessions()) {
            sendSafely(s, message);
        }
    }

    /**
     * Sends {@code message} to {@code session} if it is open, logging (not propagating) failures.
     *
     * <p>Handler callbacks for different sessions may run on different threads, so a broadcast can
     * race with another send to the same session. Standard WebSocket sessions do not permit
     * concurrent sends, so writes are serialized per session object.
     * NOTE(review): if WebSocketUserSessionUtil stores wrapped sessions, consider
     * ConcurrentWebSocketSessionDecorator instead.
     */
    private static void sendSafely(WebSocketSession session, TextMessage message) {
        synchronized (session) {
            if (!session.isOpen()) {
                return;
            }
            try {
                session.sendMessage(message);
            } catch (Exception e) {
                // Best effort: one broken connection must not abort delivery to the others.
                e.printStackTrace();
            }
        }
    }

    private void sendJsonMessage(WebSocketSession session, String fromUserId, String fromNickName, String type, String content) {
        sendSafely(session, buildChatMessage(fromUserId, fromNickName, type, content));
    }

    private void broadcastJsonMessage(String fromUserId, String type, String content) {
        broadcast(buildChatMessage(fromUserId, SYSTEM_NICK_NAME, type, content));
    }

    private void sendSystemMessage(WebSocketSession session, String content) {
        sendJsonMessage(session, "system", SYSTEM_NICK_NAME, "system", content);
    }

}
